{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# hide\n",
    "%reload_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "from functools import partial\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "from model_constructor.layers import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# model_constructor\n",
    "\n",
    "> Constructor to create pytorch model."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "_"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Install"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`pip install model-constructor`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Or install from repo:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`pip install git+https://github.com/ayasyrev/model_constructor.git`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## How to use"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It can be used in two ways.  \n",
    "Classic - create a model from a function with parameters.   \n",
    "Or by creating a constructor object, modifying it, and then creating the model.\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Class Net"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First import the constructor class, then create a model constructor object."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from model_constructor.net import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = Net()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       " constr Net"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we have a model constructor, with default settings as xresnet18. And we can get the model by calling it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.c_in"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1000"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.c_out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[3, 32, 32, 64]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.stem_sizes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[2, 2, 2, 2]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.expansion"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  model Net\n",
       "  (stem): Sequential(\n",
       "    (conv_0): ConvLayer(\n",
       "      (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "      (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (conv_1): ConvLayer(\n",
       "      (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (conv_2): ConvLayer(\n",
       "      (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (stem_pool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
       "  )\n",
       "  (body): Sequential(\n",
       "    (l_0): Sequential(\n",
       "      (bl_0): ResBlock(\n",
       "        (convs): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "      (bl_1): ResBlock(\n",
       "        (convs): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "    )\n",
       "    (l_1): Sequential(\n",
       "      (bl_0): ResBlock(\n",
       "        (convs): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n",
       "        (idconv): ConvLayer(\n",
       "          (conv): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "      (bl_1): ResBlock(\n",
       "        (convs): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "    )\n",
       "    (l_2): Sequential(\n",
       "      (bl_0): ResBlock(\n",
       "        (convs): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n",
       "        (idconv): ConvLayer(\n",
       "          (conv): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "      (bl_1): ResBlock(\n",
       "        (convs): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "    )\n",
       "    (l_3): Sequential(\n",
       "      (bl_0): ResBlock(\n",
       "        (convs): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n",
       "        (idconv): ConvLayer(\n",
       "          (conv): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "      (bl_1): ResBlock(\n",
       "        (convs): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (head): Sequential(\n",
       "    (pool): AdaptiveAvgPool2d(output_size=1)\n",
       "    (flat): Flatten()\n",
       "    (fc): Linear(in_features=512, out_features=1000, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you want to change the model, just change the constructor parameters.  \n",
    "Let's create xresnet50."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.expansion = 4\n",
    "model.layers = [3,4,6,3]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can look at the model body, and if we call the constructor, we have a pytorch model!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  (l_0): Sequential(\n",
       "    (bl_0): ResBlock(\n",
       "      (convs): Sequential(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_2): ConvLayer(\n",
       "          (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (bl_1): ResBlock(\n",
       "      (convs): Sequential(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_2): ConvLayer(\n",
       "          (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (bl_2): ResBlock(\n",
       "      (convs): Sequential(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_2): ConvLayer(\n",
       "          (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "  )\n",
       "  (l_1): Sequential(\n",
       "    (bl_0): ResBlock(\n",
       "      (convs): Sequential(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_2): ConvLayer(\n",
       "          (conv): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n",
       "      (idconv): ConvLayer(\n",
       "        (conv): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (bl_1): ResBlock(\n",
       "      (convs): Sequential(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_2): ConvLayer(\n",
       "          (conv): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (bl_2): ResBlock(\n",
       "      (convs): Sequential(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_2): ConvLayer(\n",
       "          (conv): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (bl_3): ResBlock(\n",
       "      (convs): Sequential(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_2): ConvLayer(\n",
       "          (conv): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "  )\n",
       "  (l_2): Sequential(\n",
       "    (bl_0): ResBlock(\n",
       "      (convs): Sequential(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_2): ConvLayer(\n",
       "          (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n",
       "      (idconv): ConvLayer(\n",
       "        (conv): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (bl_1): ResBlock(\n",
       "      (convs): Sequential(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_2): ConvLayer(\n",
       "          (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (bl_2): ResBlock(\n",
       "      (convs): Sequential(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_2): ConvLayer(\n",
       "          (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (bl_3): ResBlock(\n",
       "      (convs): Sequential(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_2): ConvLayer(\n",
       "          (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (bl_4): ResBlock(\n",
       "      (convs): Sequential(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_2): ConvLayer(\n",
       "          (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (bl_5): ResBlock(\n",
       "      (convs): Sequential(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_2): ConvLayer(\n",
       "          (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "  )\n",
       "  (l_3): Sequential(\n",
       "    (bl_0): ResBlock(\n",
       "      (convs): Sequential(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_2): ConvLayer(\n",
       "          (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n",
       "      (idconv): ConvLayer(\n",
       "        (conv): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (bl_1): ResBlock(\n",
       "      (convs): Sequential(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_2): ConvLayer(\n",
       "          (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (bl_2): ResBlock(\n",
       "      (convs): Sequential(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_2): ConvLayer(\n",
       "          (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.body"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([16, 1000])\n"
     ]
    }
   ],
   "source": [
    "# hide\n",
    "bs_test = 16\n",
    "xb = torch.randn(bs_test, 3, 128, 128)\n",
    "y = model()(xb)\n",
    "print(y.shape)\n",
    "assert y.shape == torch.Size([bs_test, 1000]), f\"size expected {bs_test}, 1000\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[16, 64, 128, 256, 512]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.block_szs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## More modification."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The main purpose of this module is to make modifying a model fast and easy.\n",
    "Here is a link to further modifications for beating the Imagenette leaderboard, adding MaxBlurPool and modifying ResBlock: https://github.com/ayasyrev/imagenette_experiments/blob/master/ResnetTrick_create_model_fit.ipynb  \n",
    "\n",
    "But for now, let's create a model like mxresnet50 from the fastai forums thread https://forums.fast.ai/t/how-we-beat-the-5-epoch-imagewoof-leaderboard-score-some-new-techniques-to-consider  \n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's create the mxresnet constructor."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mxresnet = Net()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then let's modify the stem."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mxresnet.stem_sizes = [3,32,64,64]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's change the activation function to Mish.\n",
    "Here is a link to the forum discussion: https://forums.fast.ai/t/meet-mish-new-activation-function-possible-successor-to-relu"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# hide\n",
    "import torch.nn as nn\n",
    "# import torch.functional as F\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Mish(nn.Module):\n",
    "    \"\"\"Mish activation, computed element-wise as x * tanh(softplus(x)).\"\"\"\n",
    "\n",
    "    def forward(self, x):\n",
    "        return x * torch.tanh(F.softplus(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mxresnet.expansion = 4\n",
    "mxresnet.layers = [3,4,6,3]\n",
    "mxresnet.act_fn = Mish()\n",
    "mxresnet.name = 'mxresnet50'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we have the mxresnet50 constructor.  \n",
    "We can inspect some parts of it.  \n",
    "And after calling it we get the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       " constr mxresnet50"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mxresnet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ConvLayer(\n",
       "  (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "  (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (act_fn): Mish()\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mxresnet.stem.conv_1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ResBlock(\n",
       "  (convs): Sequential(\n",
       "    (conv_0): ConvLayer(\n",
       "      (conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (act_fn): Mish()\n",
       "    )\n",
       "    (conv_1): ConvLayer(\n",
       "      (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (act_fn): Mish()\n",
       "    )\n",
       "    (conv_2): ConvLayer(\n",
       "      (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "  )\n",
       "  (idconv): ConvLayer(\n",
       "    (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "    (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  )\n",
       "  (act_fn): Mish()\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mxresnet.body.l_0.bl_0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([16, 1000])\n"
     ]
    }
   ],
   "source": [
    "# hide\n",
    "bs_test = 16\n",
    "xb = torch.randn(bs_test, 3, 128, 128)\n",
    "y = mxresnet()(xb)\n",
    "print(y.shape)\n",
    "assert y.shape == torch.Size([bs_test, 1000]), f\"size expected {bs_test}, 1000\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's change the ResBlock. NewResBlock (it still does not have a name of its own) has been in the library since version 0.1.0."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mxresnet.block = NewResBlock"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "That's all. Let's see what we have."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "NewResBlock(\n",
       "  (reduce): AvgPool2d(kernel_size=2, stride=2, padding=0)\n",
       "  (convs): Sequential(\n",
       "    (conv_0): ConvLayer(\n",
       "      (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (act_fn): Mish()\n",
       "    )\n",
       "    (conv_1): ConvLayer(\n",
       "      (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (act_fn): Mish()\n",
       "    )\n",
       "    (conv_2): ConvLayer(\n",
       "      (conv): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "  )\n",
       "  (idconv): ConvLayer(\n",
       "    (conv): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "    (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  )\n",
       "  (merge): Mish()\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mxresnet.body.l_1.bl_0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([16, 1000])\n"
     ]
    }
   ],
   "source": [
    "# hide\n",
    "bs_test = 16\n",
    "xb = torch.randn(bs_test, 3, 128, 128)\n",
    "y = mxresnet()(xb)\n",
    "print(y.shape)\n",
    "assert y.shape == torch.Size([bs_test, 1000]), f\"size expected {bs_test}, 1000\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Classic way"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The usual way to get a model is to call the constructor with parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from model_constructor.constructor import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Default is resnet18."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = Net()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Net(\n",
       "  (stem): Stem(\n",
       "    sizes: [3, 64]\n",
       "    (conv_0): ConvLayer(\n",
       "      (conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (pool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
       "  )\n",
       "  (body): Body(\n",
       "    (layer_0): BasicLayer(\n",
       "      from 64 to 64, 2 blocks, expansion 1.\n",
       "      (block_0): BasicBlock(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_1): BasicBlock(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "    )\n",
       "    (layer_1): BasicLayer(\n",
       "      from 64 to 128, 2 blocks, expansion 1.\n",
       "      (block_0): BasicBlock(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (downsample): ConvLayer(\n",
       "          (conv): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_1): BasicBlock(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "    )\n",
       "    (layer_2): BasicLayer(\n",
       "      from 128 to 256, 2 blocks, expansion 1.\n",
       "      (block_0): BasicBlock(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (downsample): ConvLayer(\n",
       "          (conv): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_1): BasicBlock(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "    )\n",
       "    (layer_3): BasicLayer(\n",
       "      from 256 to 512, 2 blocks, expansion 1.\n",
       "      (block_0): BasicBlock(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (downsample): ConvLayer(\n",
       "          (conv): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_1): BasicBlock(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (head): Head(\n",
       "    (pool): AdaptiveAvgPool2d(output_size=(1, 1))\n",
       "    (flat): Flatten()\n",
       "    (fc): Linear(in_features=512, out_features=1000, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# hide\n",
    "model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can't modify the model after calling the constructor, so define the model with parameters.   \n",
    "For example, resnet34:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "resnet34 = Net(block=BasicBlock, blocks=[3, 4, 6, 3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Net(\n",
       "  (stem): Stem(\n",
       "    sizes: [3, 64]\n",
       "    (conv_0): ConvLayer(\n",
       "      (conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (pool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
       "  )\n",
       "  (body): Body(\n",
       "    (layer_0): BasicLayer(\n",
       "      from 64 to 64, 3 blocks, expansion 1.\n",
       "      (block_0): BasicBlock(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_1): BasicBlock(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_2): BasicBlock(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "    )\n",
       "    (layer_1): BasicLayer(\n",
       "      from 64 to 128, 4 blocks, expansion 1.\n",
       "      (block_0): BasicBlock(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (downsample): ConvLayer(\n",
       "          (conv): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_1): BasicBlock(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_2): BasicBlock(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_3): BasicBlock(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "    )\n",
       "    (layer_2): BasicLayer(\n",
       "      from 128 to 256, 6 blocks, expansion 1.\n",
       "      (block_0): BasicBlock(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (downsample): ConvLayer(\n",
       "          (conv): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_1): BasicBlock(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_2): BasicBlock(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_3): BasicBlock(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_4): BasicBlock(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_5): BasicBlock(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "    )\n",
       "    (layer_3): BasicLayer(\n",
       "      from 256 to 512, 3 blocks, expansion 1.\n",
       "      (block_0): BasicBlock(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (downsample): ConvLayer(\n",
       "          (conv): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_1): BasicBlock(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_2): BasicBlock(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (head): Head(\n",
       "    (pool): AdaptiveAvgPool2d(output_size=(1, 1))\n",
       "    (flat): Flatten()\n",
       "    (fc): Linear(in_features=512, out_features=1000, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# hide\n",
    "resnet34"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Predefined Resnet models - 18, 34, 50."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from model_constructor.resnet import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = resnet34(num_classes=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Net(\n",
       "  (stem): Stem(\n",
       "    sizes: [3, 64]\n",
       "    (conv_0): ConvLayer(\n",
       "      (conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (pool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
       "  )\n",
       "  (body): Body(\n",
       "    (layer_0): BasicLayer(\n",
       "      from 64 to 64, 3 blocks, expansion 1.\n",
       "      (block_0): BasicBlock(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_1): BasicBlock(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_2): BasicBlock(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "    )\n",
       "    (layer_1): BasicLayer(\n",
       "      from 64 to 128, 4 blocks, expansion 1.\n",
       "      (block_0): BasicBlock(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (downsample): ConvLayer(\n",
       "          (conv): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_1): BasicBlock(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_2): BasicBlock(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_3): BasicBlock(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "    )\n",
       "    (layer_2): BasicLayer(\n",
       "      from 128 to 256, 6 blocks, expansion 1.\n",
       "      (block_0): BasicBlock(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (downsample): ConvLayer(\n",
       "          (conv): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_1): BasicBlock(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_2): BasicBlock(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_3): BasicBlock(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_4): BasicBlock(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_5): BasicBlock(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "    )\n",
       "    (layer_3): BasicLayer(\n",
       "      from 256 to 512, 3 blocks, expansion 1.\n",
       "      (block_0): BasicBlock(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (downsample): ConvLayer(\n",
       "          (conv): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_1): BasicBlock(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_2): BasicBlock(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (head): Head(\n",
       "    (pool): AdaptiveAvgPool2d(output_size=(1, 1))\n",
       "    (flat): Flatten()\n",
       "    (fc): Linear(in_features=512, out_features=10, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# hide\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = resnet50(num_classes=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Net(\n",
       "  (stem): Stem(\n",
       "    sizes: [3, 64]\n",
       "    (conv_0): ConvLayer(\n",
       "      (conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (pool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
       "  )\n",
       "  (body): Body(\n",
       "    (layer_0): BasicLayer(\n",
       "      from 64 to 64, 3 blocks, expansion 4.\n",
       "      (block_0): Bottleneck(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_2): ConvLayer(\n",
       "            (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (downsample): ConvLayer(\n",
       "          (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_1): Bottleneck(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_2): ConvLayer(\n",
       "            (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_2): Bottleneck(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_2): ConvLayer(\n",
       "            (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "    )\n",
       "    (layer_1): BasicLayer(\n",
       "      from 256 to 128, 4 blocks, expansion 4.\n",
       "      (block_0): Bottleneck(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_2): ConvLayer(\n",
       "            (conv): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (downsample): ConvLayer(\n",
       "          (conv): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_1): Bottleneck(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_2): ConvLayer(\n",
       "            (conv): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_2): Bottleneck(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_2): ConvLayer(\n",
       "            (conv): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_3): Bottleneck(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_2): ConvLayer(\n",
       "            (conv): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "    )\n",
       "    (layer_2): BasicLayer(\n",
       "      from 512 to 256, 6 blocks, expansion 4.\n",
       "      (block_0): Bottleneck(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_2): ConvLayer(\n",
       "            (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (downsample): ConvLayer(\n",
       "          (conv): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "          (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_1): Bottleneck(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_2): ConvLayer(\n",
       "            (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_2): Bottleneck(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_2): ConvLayer(\n",
       "            (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_3): Bottleneck(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_2): ConvLayer(\n",
       "            (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_4): Bottleneck(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_2): ConvLayer(\n",
       "            (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_5): Bottleneck(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_2): ConvLayer(\n",
       "            (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "    )\n",
       "    (layer_3): BasicLayer(\n",
       "      from 1024 to 512, 3 blocks, expansion 4.\n",
       "      (block_0): Bottleneck(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_2): ConvLayer(\n",
       "            (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (downsample): ConvLayer(\n",
       "          (conv): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "          (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_1): Bottleneck(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_2): ConvLayer(\n",
       "            (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_2): Bottleneck(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_2): ConvLayer(\n",
       "            (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (head): Head(\n",
       "    (pool): AdaptiveAvgPool2d(output_size=(1, 1))\n",
       "    (flat): Flatten()\n",
       "    (fc): Linear(in_features=2048, out_features=10, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# hide\n",
    "model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Predefined Xresnet from fastai 1."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is a simplified version of the XResNet from fastai v1. I refactored it to make the models easier to understand and experiment with. For example, it's very simple to change activation functions, use different stems, change the batchnorm and activation order, etc. fastai v2 has a much more powerful implementation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from model_constructor.xresnet import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = xresnet50()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Net(\n",
       "  (stem): Stem(\n",
       "    sizes: [3, 32, 32, 64]\n",
       "    (conv_0): ConvLayer(\n",
       "      (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "      (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (conv_1): ConvLayer(\n",
       "      (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (conv_2): ConvLayer(\n",
       "      (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (pool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
       "  )\n",
       "  (body): Body(\n",
       "    (layer_0): BasicLayer(\n",
       "      from 64 to 64, 3 blocks, expansion 4.\n",
       "      (block_0): XResBlock(\n",
       "        (convs): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_2): ConvLayer(\n",
       "            (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (identity): DownsampleLayer(\n",
       "          (idconv): ConvLayer(\n",
       "            (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_1): XResBlock(\n",
       "        (convs): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_2): ConvLayer(\n",
       "            (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (identity): Noop()\n",
       "        (merge): Noop()\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_2): XResBlock(\n",
       "        (convs): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_2): ConvLayer(\n",
       "            (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (identity): Noop()\n",
       "        (merge): Noop()\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "    )\n",
       "    (layer_1): BasicLayer(\n",
       "      from 256 to 128, 4 blocks, expansion 4.\n",
       "      (block_0): XResBlock(\n",
       "        (convs): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_2): ConvLayer(\n",
       "            (conv): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (identity): DownsampleLayer(\n",
       "          (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n",
       "          (idconv): ConvLayer(\n",
       "            (conv): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_1): XResBlock(\n",
       "        (convs): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_2): ConvLayer(\n",
       "            (conv): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (identity): Noop()\n",
       "        (merge): Noop()\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_2): XResBlock(\n",
       "        (convs): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_2): ConvLayer(\n",
       "            (conv): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (identity): Noop()\n",
       "        (merge): Noop()\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_3): XResBlock(\n",
       "        (convs): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_2): ConvLayer(\n",
       "            (conv): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (identity): Noop()\n",
       "        (merge): Noop()\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "    )\n",
       "    (layer_2): BasicLayer(\n",
       "      from 512 to 256, 6 blocks, expansion 4.\n",
       "      (block_0): XResBlock(\n",
       "        (convs): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_2): ConvLayer(\n",
       "            (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (identity): DownsampleLayer(\n",
       "          (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n",
       "          (idconv): ConvLayer(\n",
       "            (conv): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_1): XResBlock(\n",
       "        (convs): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_2): ConvLayer(\n",
       "            (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (identity): Noop()\n",
       "        (merge): Noop()\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_2): XResBlock(\n",
       "        (convs): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_2): ConvLayer(\n",
       "            (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (identity): Noop()\n",
       "        (merge): Noop()\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_3): XResBlock(\n",
       "        (convs): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_2): ConvLayer(\n",
       "            (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (identity): Noop()\n",
       "        (merge): Noop()\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_4): XResBlock(\n",
       "        (convs): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_2): ConvLayer(\n",
       "            (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (identity): Noop()\n",
       "        (merge): Noop()\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_5): XResBlock(\n",
       "        (convs): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_2): ConvLayer(\n",
       "            (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (identity): Noop()\n",
       "        (merge): Noop()\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "    )\n",
       "    (layer_3): BasicLayer(\n",
       "      from 1024 to 512, 3 blocks, expansion 4.\n",
       "      (block_0): XResBlock(\n",
       "        (convs): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_2): ConvLayer(\n",
       "            (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (identity): DownsampleLayer(\n",
       "          (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n",
       "          (idconv): ConvLayer(\n",
       "            (conv): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_1): XResBlock(\n",
       "        (convs): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_2): ConvLayer(\n",
       "            (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (identity): Noop()\n",
       "        (merge): Noop()\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_2): XResBlock(\n",
       "        (convs): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_2): ConvLayer(\n",
       "            (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (identity): Noop()\n",
       "        (merge): Noop()\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (head): Head(\n",
       "    (pool): AdaptiveAvgPool2d(output_size=(1, 1))\n",
       "    (flat): Flatten()\n",
       "    (fc): Linear(in_features=2048, out_features=1000, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# hide\n",
    "model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Some examples."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can experiment with models by changing some parts of model. Here only base functionality, but it can be easily extanded."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here is some examples:\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Custom stem"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Stem with 3 conv layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = Net(stem=partial(Stem, stem_sizes=[32, 32]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Stem(\n",
       "  sizes: [3, 32, 32, 64]\n",
       "  (conv_0): ConvLayer(\n",
       "    (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "    (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (act_fn): ReLU(inplace=True)\n",
       "  )\n",
       "  (conv_1): ConvLayer(\n",
       "    (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "    (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (act_fn): ReLU(inplace=True)\n",
       "  )\n",
       "  (conv_2): ConvLayer(\n",
       "    (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "    (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (act_fn): ReLU(inplace=True)\n",
       "  )\n",
       "  (pool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.stem"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = Net(stem_sizes=[32, 64])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Stem(\n",
       "  sizes: [3, 32, 64, 64]\n",
       "  (conv_0): ConvLayer(\n",
       "    (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "    (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (act_fn): ReLU(inplace=True)\n",
       "  )\n",
       "  (conv_1): ConvLayer(\n",
       "    (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "    (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (act_fn): ReLU(inplace=True)\n",
       "  )\n",
       "  (conv_2): ConvLayer(\n",
       "    (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "    (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (act_fn): ReLU(inplace=True)\n",
       "  )\n",
       "  (pool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.stem"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Activation function before Normalization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = Net(bn_1st=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Stem(\n",
       "  sizes: [3, 64]\n",
       "  (conv_0): ConvLayer(\n",
       "    (conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "    (act_fn): ReLU(inplace=True)\n",
       "    (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  )\n",
       "  (pool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.stem"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "### Change activation function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# hide\n",
    "import torch.nn as nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "new_act_fn = nn.LeakyReLU(inplace=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = Net(act_fn=new_act_fn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Stem(\n",
       "  sizes: [3, 64]\n",
       "  (conv_0): ConvLayer(\n",
       "    (conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "    (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (act_fn): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "  )\n",
       "  (pool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.stem"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "BasicBlock(\n",
       "  (conv): Sequential(\n",
       "    (conv_0): ConvLayer(\n",
       "      (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (act_fn): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "    )\n",
       "    (conv_1): ConvLayer(\n",
       "      (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "  )\n",
       "  (merge): Noop()\n",
       "  (act_conn): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.body.layer_0.block_0"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
